
[RAY AIR][DOC][TorchTrainer] Rewrote the TorchTrainer code snippet as a working example #30492

Merged: 11 commits, Nov 28, 2022

Conversation

@dmatrix (Contributor) commented Nov 18, 2022

Signed-off-by: Jules Damji [email protected]

  • Rewrote the code snippet, as it was not working
  • Removed the python code-block directives; use testcode and testoutput instead, so the snippet is run and verified in CI (see the sketch below)
  • Ignore the output, since we get loads of output from the three workers
  • Assert that the loss converges on the training data within the specified number of epochs
  • Tested the code end-to-end with this script
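
For illustration, a minimal sketch of the directive pattern this PR adopts (the Python body below is a hypothetical stand-in, not the actual doc snippet):

.. testcode::

    # Hypothetical stand-in for the real snippet, which builds a
    # TorchTrainer and calls trainer.fit().
    final_loss = 0.005
    # The loss should converge on the training data within the specified epochs.
    assert final_loss < 0.01

.. testoutput::
    :hide:
    :options: +ELLIPSIS

    ...

With sphinx.ext.doctest, testcode blocks are executed in CI, and a hidden testoutput with an ellipsis tolerates the verbose output from multiple workers.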

Checks

  • [x] I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • [x] I've run scripts/format.sh to lint the changes in this PR.

…nd use testcode, and ignore long output from train

Signed-off-by: Jules Damji <[email protected]>
@dmatrix dmatrix changed the title [RAY AIR][TRAIN][DOC][TorchTrainer] Rewrote the TorchTrainer code snippet as a working example [RAY AIR][DOC][TorchTrainer] Rewrote the TorchTrainer code snippet as a working example Nov 19, 2022
@@ -22,13 +22,14 @@ class TorchTrainer(DataParallelTrainer):
The ``train_loop_per_worker`` function is expected to take in either 0 or 1
Contributor:

In the paragraph above this, it says "already" twice in the sentence -- it would be great to also fix this :)

Contributor Author:

Good catch. Fixed.


from typing import Dict
def train_loop_per_worker(config: Dict):
Contributor:

Ideally this would have a bit more typing like Dict[str, Any] (not sure what exactly the format here is) and also link to the format of the dict if possible :)

Contributor Author:

Yeah, we can add some typing.
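
For example, a sketch of the suggested typing (the config keys here are hypothetical; the dict simply carries whatever was passed as train_loop_config):

from typing import Any, Dict

def train_loop_per_worker(config: Dict[str, Any]) -> None:
    # config holds the user-defined values passed via train_loop_config.
    lr = config.get("lr", 1e-3)  # hypothetical hyperparameter
    num_epochs = config.get("num_epochs", 20)  # hypothetical hyperparameter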

@@ -45,32 +46,33 @@ def train_loop_per_worker(config: Dict):
Inside the ``train_loop_per_worker`` function, you can use any of the
Contributor:

Ideally there would also be an example for the above paragraph somewhere; feel free to do that in another PR.

Contributor:

(You can discard this; I saw the usage is already shown in the example below -- maybe add "(see example below)".)


def train_loop_per_worker():
    # Report intermediate results for callbacks or logging and
    # checkpoint data.
    #
Contributor:

I feel like it was better without this line, but if you prefer, feel free to keep it :)

Contributor Author:

Since the code is incomplete (session.report(...) and session.get_checkpoint()), it's nice to explain it with a comment.

    session.report(...)

    # Returns dict of last saved checkpoint.
    # Session returns dict of last saved checkpoint.
@pcmoritz (Contributor), Nov 21, 2022:

Say "Get dict of last saved checkpoint." here (same below)? "Session returns" is a little confusing, I think, since technically session is a Python module here and it doesn't return anything :)

Contributor Author:

Yes, "Get x" makes sense rather than "Returns", since it's an explicit method call to session.get_xxx.
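
To make the two calls concrete, a minimal sketch against the Ray 2.x AIR session API (the "epoch" checkpoint key is an assumption for illustration):

from ray.air import session
from ray.air.checkpoint import Checkpoint

def train_loop_per_worker():
    # Get dict of last saved checkpoint (None on the first run).
    checkpoint = session.get_checkpoint()
    start_epoch = checkpoint.to_dict()["epoch"] + 1 if checkpoint else 0

    # Report intermediate results for callbacks or logging, and
    # save checkpoint data.
    session.report(
        {"loss": 0.0},  # placeholder metric
        checkpoint=Checkpoint.from_dict({"epoch": start_epoch}),
    )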

        self.layer2 = nn.Linear(layer_size, output_size)

    def forward(self, input):
        return self.layer2(self.relu(self.layer1(input)))
Contributor:

I would either keep the ReLU layer here or have only one linear layer -- composing two linear layers doesn't do anything and it would likely be confusing to users :)

Contributor Author:

Keeping ReLU does not make sense. Why add non-linearity to a linear data relationship? With ReLU the model does not converge; it oscillates like a seesaw. Having two linear layers is not uncommon. We can put in a comment: you can also use one layer if the relationship between your data and outcome (target) is linear.
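
For reference, a sketch of the model shape under discussion (class and argument names are assumed from the diff context):

import torch.nn as nn

class NeuralNetwork(nn.Module):
    def __init__(self, input_size: int, layer_size: int, output_size: int):
        super().__init__()
        self.layer1 = nn.Linear(input_size, layer_size)
        self.layer2 = nn.Linear(layer_size, output_size)

    def forward(self, input):
        # Two stacked linear layers; for a purely linear data-target
        # relationship a single nn.Linear would also work, and a ReLU
        # in between only helps when the relationship is non-linear.
        return self.layer2(self.layer1(input))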


        # Report and record metrics, checkpoint model at end of each
        # epoch
        session.report({"loss": loss.item(), "epoch": epoch},
Contributor:

This is confusing since epoch is both here and below. @amogkam, can you recommend how to do this? Most users will follow the example, so we should make sure we do this well :)

Contributor Author:

One is reporting the loss per epoch as a metric; the other is there for the checkpoint per epoch. It's nice to have those metrics per epoch. If @amogkam feels strongly that we should not include "epoch" in the metrics to report, then I can remove that entry.
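
As a sketch of what is being discussed (a hypothetical helper; it must run inside a training loop driven by TorchTrainer):

import torch
from ray.air import session
from ray.air.checkpoint import Checkpoint

def report_epoch(model: torch.nn.Module, loss: torch.Tensor, epoch: int) -> None:
    # Report and record metrics, and checkpoint the model at the end of
    # each epoch. "epoch" appears twice deliberately: once as a reported
    # metric, once as checkpoint state used to resume training.
    session.report(
        {"loss": loss.item(), "epoch": epoch},
        checkpoint=Checkpoint.from_dict(
            {"epoch": epoch, "model": model.state_dict()}
        ),
    )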

result = trainer.fit()

# Get the loss metric from TorchCheckpoint tuple data dictionary
best_checkpoint_loss = result.metrics['loss']
# print(f"best loss: {best_checkpoint_loss:.4f}")
Contributor:

Should you remove the # here?

Contributor Author:

Yeah, it's a bit redundant since the code is self-explanatory.
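
Putting it together, a sketch of fetching the metric (assuming train_loop_per_worker is defined as above and a three-worker setup, per the PR description):

from ray.air.config import ScalingConfig
from ray.train.torch import TorchTrainer

trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=3),
)
result = trainer.fit()

# result.metrics holds the last metrics dict passed to session.report().
best_checkpoint_loss = result.metrics["loss"]
assert best_checkpoint_loss < 0.01  # hypothetical convergence bound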

train_loop_per_worker: The training function to execute.
This can either take in no arguments or a ``config`` dict.
train_loop_config: Configurations to pass into
train_loop_config: Configurations to pass into
Contributor:

The indentation should be kept here, right? Otherwise it won't render correctly :)

@dmatrix (Contributor Author), Nov 21, 2022:

The Args <parameter_name>: entries should be indented at the same level. That is:

Args:
        arg_1:  ...
        arg_2: ...

Contributor Author:

[Screenshot: Screen Shot 2022-11-21 at 3 25 39 PM]

They seem to render properly.
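
For reference, the layout under discussion, reconstructed from the diff context (a sketch, not the exact docstring):

Args:
    train_loop_per_worker: The training function to execute.
        This can either take in no arguments or a ``config`` dict.
    train_loop_config: Configurations to pass into
        ``train_loop_per_worker`` if it accepts an argument.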

@pcmoritz (Contributor) left a comment:

Thanks for doing this, this is great! There are a few small comments you should address before merging :)

@dmatrix (Contributor Author) commented Nov 27, 2022:

@pcmoritz, can we merge this and #30637 if you don't see any issues?

Signed-off-by: Philipp Moritz <[email protected]>
Signed-off-by: Philipp Moritz <[email protected]>
Signed-off-by: Philipp Moritz <[email protected]>
@pcmoritz (Contributor):

Looks like some of the Hugging Face servers are down, which is independent of this PR; we can merge it after the tests have run.


@pcmoritz merged commit 19aadd4 into master on Nov 28, 2022
@pcmoritz deleted the br_jsd_improve_code_snippets_pr_3 branch on November 28, 2022 07:18
WeichenXu123 pushed a commit to WeichenXu123/ray that referenced this pull request on Dec 19, 2022:
… a working example (ray-project#30492)

Signed-off-by: Weichen Xu <[email protected]>
tamohannes pushed a commit to ju2ez/ray that referenced this pull request on Jan 25, 2023:
… a working example (ray-project#30492)

Signed-off-by: tmynn <[email protected]>
Labels: none yet
Projects: none yet
Participants: 2